"""General-purpose test script for image-to-image translation.

Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from --checkpoints_dir and save the results to --results_dir.

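It first creates the model and dataset from the given options, hard-coding some parameters.
It then runs inference for --num_test images and saves the results to an HTML file.
(This variant instead renders per-layer log-density heatmaps for the images configured in the
main block and writes them under ./imgs/.)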

Example (You need to train models first or download pre-trained models from our website):
    Test a CycleGAN model (both sides):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    Test a CycleGAN model (one side only):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

    The option '--model test' is used for generating CycleGAN results only for one side.
    This option will automatically set '--dataset_mode single', which only loads the images from one set.
    On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
    which is sometimes unnecessary. The results will be saved at ./results/.
    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.

    Test a pix2pix model:
        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import util.util as util
from torch_receptive_field import receptive_field, receptive_field_for_unit
from PIL import Image
import sys
from torchvision.transforms import transforms
import numpy as np
import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_theme()
import torch
sns.light_palette("seagreen", as_cmap=True)
import torchvision

# Layer names in ``receptive_field_dict`` matching, in order, the feature maps
# returned by ``model.netG(img, layers=model.var_layers)``.
# NOTE(review): must stay in sync with model.var_layers -- confirm.
LAYER_IDS = ["1", "5", "13", "22", "54"]
# Side length images are resized to; also the size of every density canvas.
IMAGE_SIZE = 256


def load_image_tensor(path, size=IMAGE_SIZE):
    """Load ``path`` as an RGB ``(1, 3, size, size)`` tensor normalized to [-1, 1]."""
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # ``convert`` returns a fully loaded copy, so the file handle can be closed.
    with Image.open(path) as image:
        rgb = image.convert('RGB')
    return transform(rgb).unsqueeze(0)


def add_to_receptive_field(canvas, rf_range, value):
    """Add ``value`` onto ``canvas`` over the window ``rf_range`` (in place).

    ``rf_range`` is ``[(row_start, row_end), (col_start, col_end)]``, the format
    returned by ``receptive_field_for_unit``. Bounds are truncated to ``int`` and
    clamped at 0. A window overhanging the top/left edge is therefore clipped
    instead of turning into a wrap-around negative-index slice.

    Returns:
        The same ``canvas`` object, for convenience.
    """
    (row_start, row_end), (col_start, col_end) = rf_range
    rows = slice(max(int(row_start), 0), max(int(row_end), 0))
    cols = slice(max(int(col_start), 0), max(int(col_end), 0))
    canvas[rows, cols] += value
    return canvas


def standardize(values):
    """Return ``values`` z-scored; a constant array is only mean-centred (no 0-division)."""
    std = values.std()
    return (values - values.mean()) / (std if std > 0 else 1.0)


def density_map(model, receptive_field_dict, feat, feat_id, layer, size=IMAGE_SIZE):
    """Project per-location log-densities of one feature map back onto image space.

    Each spatial unit (i, j) of ``feat`` is scored by the estimator
    ``model.netF_A.module.flow_<feat_id>``. The resulting log-density is added over
    that unit's receptive field (looked up for ``layer``) in a ``size x size``
    canvas, so overlapping fields accumulate.

    Args:
        feat: feature tensor indexed as ``feat.shape[2]`` x ``feat.shape[3]``
            spatially; only the first batch element is projected (batch size is 1 here).
        feat_id: index selecting the ``flow_<feat_id>`` estimator.
        layer: key of ``receptive_field_dict`` describing this feature level.

    Returns:
        A ``(size, size)`` float numpy array of accumulated log-densities.
    """
    height, width = feat.shape[2], feat.shape[3]
    # Channels-last, then flatten to one row per spatial location, ordered (h, w).
    feat_flat = feat.permute(0, 2, 3, 1).flatten(1, 2).flatten(0, 1)
    estimator = getattr(model.netF_A.module, 'flow_%d' % feat_id)
    # Move scores to numpy once instead of one ``.item()`` sync per location.
    log_probs = estimator.log_probs(feat_flat).detach().cpu().numpy().reshape(-1)
    canvas = np.zeros([size, size])
    for i in range(height):
        for j in range(width):
            rf_range = receptive_field_for_unit(receptive_field_dict, layer, (i, j))
            add_to_receptive_field(canvas, rf_range, float(log_probs[i * width + j]))
    return canvas


def save_heatmap(values, out_path, **heatmap_kwargs):
    """Render ``values`` as an axis/colorbar-free seaborn heatmap and save it to ``out_path``."""
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sns.heatmap(values, xticklabels=False, yticklabels=False, cbar=False, **heatmap_kwargs)
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()
    print(' img saved !')


if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 1
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options

    # One batch is required for the model's data-dependent initialization.
    data = next(iter(dataset), None)
    if data is None:
        raise RuntimeError('dataset is empty; cannot run data_dependent_initialize')
    model.data_dependent_initialize(data)
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    model.parallelize()
    if opt.eval:
        model.eval()
    receptive_field_dict = receptive_field(model.netG, (3, IMAGE_SIZE, IMAGE_SIZE))
    print(receptive_field_dict)

    # Images to visualize (cat2dog is active). Previously used configurations:
    #   cityscapes, img_id = '9':
    #     path        = '../../DA/data/cityscapes/testB/%s_B.jpg'
    #     path_base   = 'checkpoints/BaseGAN/cityscapes_BtoA/fake/%s_B.png'
    #     path_my     = 'checkpoints/DetachA/cityscapes_BtoA/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998/fake/%s_B.png'
    #     path_negcut = '../NEGCUT/results/Cityscapes_pretrained/test_latest/images/fake_B/%s_B.png'
    #     path_qs     = '../query-selected-attention/results/cityscapes_local_global/test_latest/images/fake_B/%s_B.png'
    #     path_truth  = '../../DA/data/cityscapes/testA/%s_A.jpg'
    #   selfie2anime, img_id = '16101':
    #     path       = '../../DA/data/selfie2anime/testA/female_%s.jpg'
    #     path_base  = 'checkpoints/BaseGAN/selfie2anime_AtoB/fake/female_%s.png'
    #     path_my    = 'checkpoints/DetachA/selfie2anime_AtoB/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998/fake/female_%s.png'
    #     path_truth = path_my
    #   horse2zebra, img_id = '120':
    #     path       = '../../DA/data/horse2zebra/testA/n02381460_%s.jpg'
    #     path_base  = 'checkpoints/BaseGAN/horse2zebra_AtoB/fake/n02381460_%s.png'
    #     path_my    = 'checkpoints/DetachA/horse2zebra_AtoB/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998/fake/n02381460_%s.png'
    #     path_truth = path_my
    img_id = "3842"
    path = '../../DA/data/cat2dog/testA/pixabay_cat_00%s.jpg' % img_id
    path_base = 'checkpoints/BaseGAN/cat2dog_AtoB/fake/pixabay_cat_00%s.png' % img_id
    path_my = 'checkpoints/DetachA/cat2dog_AtoB/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998/fake/pixabay_cat_00%s.png' % img_id
    path_truth = path_my  # no separate ground truth for this dataset

    # Input image. The generator forward runs before the swap, as in the original
    # ordering. Gradients are not needed for visualization.
    with torch.no_grad():
        _, patches_input = model.netG(load_image_tensor(path), layers=model.var_layers)
    # NOTE(review): presumably swaps netF weights with their EMA copies -- confirm.
    model.optimizer_F.swap()
    for feat_id, feat in enumerate(patches_input):
        x = density_map(model, receptive_field_dict, feat, feat_id, LAYER_IDS[feat_id])
        x = x[6:-6, 8:-8]  # hard-coded border crop
        print(x.shape)
        save_heatmap(2 * np.exp(standardize(x)), 'imgs/input_ema_red_%d.jpg' % feat_id,
                     cmap='coolwarm', square=True)

    # Translated outputs. The 'truth' pass yields the per-level reference maps
    # (previously recomputed separately, with a layer-index bug, and left unused).
    for name, img_path in zip(['base', 'my', 'truth'], [path_base, path_my, path_truth]):
        print(name, img_path)
        with torch.no_grad():
            _, patches_input = model.netG(load_image_tensor(img_path), layers=model.var_layers)
        for feat_id, feat in enumerate(patches_input):
            x = density_map(model, receptive_field_dict, feat, feat_id, LAYER_IDS[feat_id])
            save_heatmap(np.exp(standardize(x)), 'imgs/%s_ema_red_%d.jpg' % (name, feat_id),
                         vmin=0, vmax=3., cmap='coolwarm')





